import typing as tp

import torch
import torch.nn as nn


class ConvLayer(nn.Module):
    """1D convolution over the time axis, followed by dropout, LayerNorm
    over the channel dimension, and a configurable activation.

    Operates on (batch, time, channels) tensors and internally transposes
    to the (batch, channels, time) layout that ``nn.Conv1d`` expects.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        activation: str = "GELU",
        dropout_rate: float = 0.0,
    ):
        super().__init__()
        # This padding keeps the output length at t // stride for typical
        # kernel/stride combinations (exact same-length when stride == 1
        # and kernel_size is odd).
        pad = (kernel_size - stride) // 2
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding=pad)
        self.drop = nn.Dropout(dropout_rate)
        self.norm = nn.LayerNorm(out_channels)
        # Resolve the activation class by name from torch.nn (e.g. "GELU", "ReLU").
        self.activ = getattr(nn, activation)()

    def forward(self, x: torch.Tensor):
        """
        Args:
            x: (b, t, c)
        Return:
            x: (b, t', c) — t' == t when stride == 1, downsampled otherwise
        """
        y = self.conv(x.transpose(2, 1)).transpose(2, 1)
        y = self.drop(y)
        y = self.norm(y)
        return self.activ(y)


class ResidualConvLayer(nn.Module):
    """Residual stack of stride-1 ConvLayer blocks: y = x + F(x).

    Every layer keeps the same channel width and (with stride 1) the same
    time length, so the skip connection is a plain elementwise addition.
    """

    def __init__(
        self,
        hidden_channels: int,
        n_layers: int = 2,
        kernel_size: int = 5,
        activation: str = "GELU",
        dropout_rate: float = 0.0,
    ):
        super().__init__()
        self.layers = nn.Sequential(
            *(
                ConvLayer(
                    hidden_channels,
                    hidden_channels,
                    kernel_size,
                    1,  # stride 1 keeps the time axis intact for the residual add
                    activation,
                    dropout_rate,
                )
                for _ in range(n_layers)
            )
        )

    def forward(self, x: torch.Tensor):
        """
        Args:
            x: (b, t, c)
        Returns:
            x: (b, t, c)
        """
        residual = self.layers(x)
        return x + residual


class ResidualConvBlock(nn.Module):
    """Two stacks of ResidualConvLayer with an optional middle layer in
    between, framed by input/output channel projections.

    The middle layer may change the time resolution (e.g. a strided
    ``nn.Conv1d`` or an ``nn.MaxPool1d`` for downsampling), or be a custom
    pooling module driven by extra keyword arguments to ``forward``.
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        n_layers: int = 2,
        n_blocks: int = 5,
        middle_layer: tp.Optional[nn.Module] = None,
        kernel_size: int = 5,
        activation: str = "GELU",
        dropout_rate: float = 0.0,
    ):
        super().__init__()

        def _proj(c_in: int, c_out: int) -> nn.Module:
            # "Same"-padded 1x-stride conv projection; skipped entirely
            # when the channel counts already match.
            if c_in == c_out:
                return nn.Identity()
            return nn.Conv1d(
                c_in,
                c_out,
                kernel_size=kernel_size,
                stride=1,
                padding=(kernel_size - 1) // 2,
            )

        def _stack() -> nn.Sequential:
            # One residual conv stack operating at hidden_channels width.
            return nn.Sequential(
                *[
                    ResidualConvLayer(
                        hidden_channels, n_layers, kernel_size, activation, dropout_rate
                    )
                    for _ in range(n_blocks)
                ]
            )

        self.in_proj = _proj(in_channels, hidden_channels)
        self.conv1 = _stack()

        if middle_layer is None:
            self.middle_layer = nn.Identity()
        elif isinstance(middle_layer, nn.Module):
            self.middle_layer = middle_layer
        else:
            raise TypeError("unknown middle layer type:{}".format(type(middle_layer)))

        self.conv2 = _stack()
        self.out_proj = _proj(hidden_channels, out_channels)

    def forward(self, x: torch.Tensor, **middle_layer_kwargs):
        """
        Args:
            x: (b, t1, c) input features.
            **middle_layer_kwargs: forwarded to a custom middle layer
                (e.g. a phoneme-pooling module); ignored for the built-in
                Identity/Conv1d/MaxPool1d cases.
        Return:
            x: (b, t2, c) — t2 may differ from t1 if the middle layer
            changes the time resolution.
        """
        x = self.in_proj(x.transpose(2, 1)).transpose(2, 1)
        x = self.conv1(x)
        if isinstance(self.middle_layer, (nn.MaxPool1d, nn.Conv1d)):
            # These modules expect (b, c, t), so transpose around the call.
            x = self.middle_layer(x.transpose(2, 1)).transpose(2, 1)
        elif isinstance(self.middle_layer, nn.Identity):
            x = self.middle_layer(x)
        else:
            # In case of a phoneme-pooling (or other custom) layer, pass
            # the extra keyword arguments straight through.
            x = self.middle_layer(x, **middle_layer_kwargs)
        x = self.conv2(x)
        x = self.out_proj(x.transpose(2, 1)).transpose(2, 1)
        return x


class MelReduceEncoder(nn.Module):
    """Encoder that downsamples the time axis by ``reduction_rate``.

    Wraps a ResidualConvBlock whose middle layer is a strided Conv1d with
    kernel == stride == reduction_rate, so the time length is divided by
    the reduction rate (no padding is applied by that middle conv).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_channels: int = 384,
        reduction_rate: int = 4,
        n_layers: int = 2,
        n_blocks: int = 5,
        kernel_size: int = 3,
        activation: str = "GELU",
        dropout: float = 0.0,
    ):
        super().__init__()
        self.reduction_rate = reduction_rate
        # Non-overlapping strided conv: each output frame summarizes
        # exactly `reduction_rate` consecutive input frames.
        reducer = nn.Conv1d(
            in_channels=hidden_channels,
            out_channels=hidden_channels,
            kernel_size=reduction_rate,
            stride=reduction_rate,
            padding=0,
        )
        self.encoder = ResidualConvBlock(
            in_channels=in_channels,
            hidden_channels=hidden_channels,
            out_channels=out_channels,
            n_layers=n_layers,
            n_blocks=n_blocks,
            middle_layer=reducer,
            kernel_size=kernel_size,
            activation=activation,
            dropout_rate=dropout,
        )

    def forward(self, x: torch.Tensor):
        """Encode (b, t, c_in) features into (b, t // reduction_rate, c_out)."""
        return self.encoder(x)
